Commit 6b3bb3ae authored by chenhw5's avatar chenhw5
Browse files

sbo-deepep-gemm based on v0.9.2-dev-0316-dp

parent 236266a9
......@@ -203,6 +203,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
VLLM_EP_USE_SBO: bool = False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM: bool = True
VLLM_ENABLE_DEEPEP_INT8_DISPATCH: bool = True
VLLM_ZERO_OVERHEAD_ENHANCE: bool = False
......@@ -1332,7 +1333,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER":
lambda: (os.getenv("VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER",
"0").lower() in ("true", "1")),
# Whether to use single batch overlapping optimization
"VLLM_EP_USE_SBO": lambda: bool(int(os.getenv("VLLM_EP_USE_SBO", "0"))),
# vLLM will use deepgemm kernel for deepep ht mode
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM":
lambda: (os.getenv('VLLM_ENABLE_DEEPEP_HT_DEEPGEMM', '1').lower() in
......
......@@ -10,6 +10,17 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_deep_gemm
from lightop import fuse_silu_mul_quant_ep, fuse_silu_mul_fp8_quant_ep
if has_deep_gemm():
from deep_gemm import m_grouped_w8a8_gemm_nt_masked, m_grouped_fp8_gemm_nt_masked
else:
from lightop import m_grouped_w8a8_gemm_nt_masked, m_grouped_fp8_gemm_nt_masked
from typing import Any
logger = init_logger(__name__)
......@@ -261,6 +272,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
w2_gemm_overlap_args: Any = None,
meta_overlap_args: dict[str, Any] | None = None,
):
import deep_gemm as dg
assert hidden_states.ndim == 3
......@@ -281,18 +294,39 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# may lead to better performance.
expected_m = max_num_tokens
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale),
(w1, w1_scale),
out=workspace1,
masked_m=expert_num_tokens,
expected_m=expected_m)
m_grouped_fp8_gemm_nt_masked(
(a1q, a1q_scale),
(w1, w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
assert expert_num_tokens is not None
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
expert_num_tokens)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
out=output,
masked_m=expert_num_tokens,
expected_m=expected_m)
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
# When using SBO, we record event here to indicate
# that the signal tensor and input to deepep ll combine
# are ready
enable_overlap = w2_gemm_overlap_args is not None
signal = w2_gemm_overlap_args.signal if enable_overlap else None
if enable_overlap:
w2_gemm_overlap_args.start_event.record()
block_m, threshold = m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
output,
expert_num_tokens,
expected_m,
enable_overlap,
signal,
)
# return meta_overlap_args to DeepEP combine.
if meta_overlap_args is not None:
if block_m is not None:
meta_overlap_args["block_m"] = block_m
if threshold is not None:
meta_overlap_args["threshold"] = threshold
......@@ -100,6 +100,7 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......
......@@ -245,6 +245,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......
......@@ -15,6 +15,47 @@ from vllm.model_executor.layers.fused_moe.utils import (
DEEPEP_QUANT_BLOCK_SIZE = 128
DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any
alt_stream = torch.cuda.Stream()
from vllm import envs
@dataclass
class CombineOverlapArgs:
    """Settings for the DeepEP low-latency combine when SBO is enabled.

    SBO (single batch overlapping) overlaps the w2 GEMM with the DeepEP
    low-latency combine. The fields below are forwarded to the combine call.
    """

    # Whether to overlap the w2 GEMM with the DeepEP low-latency combine.
    overlap: bool
    # Stream on which the DeepEP low-latency combine is launched; it is
    # different from the default compute stream.
    stream: torch.cuda.Stream
    # Event recorded on the compute stream between the fused SiLU-mul-quant
    # step and the w2 GEMM. The combine stream waits on this event before
    # launching the combine, which ensures the signal tensor has been
    # allocated.
    wait_event: torch.cuda.Event
    # Number of CUs used by the combine kernel; currently hardcoded to 32.
    num_sms: int
    # Signal tensor shared by the w2 GEMM and the DeepEP low-latency combine.
    # The w2 GEMM atomically adds to it to tell the combine that it can start
    # sending data.
    signal: torch.Tensor | None = None
    # NOTE(review): presumably the M-dimension block size used by the w2 GEMM.
    # The visible combine path takes block_m from meta_overlap_args rather
    # than from this field - TODO confirm whether this default is still used.
    block_m: int = 64
    # Set to the number of CUs used by the w2 GEMM, which is a persistent
    # kernel: once all CUs have finished the computation for an expert, the
    # combine kernel can start sending data for that expert.
    threshold: int = 32
@dataclass
class W2GemmOverlapArgs:
    """Settings passed to the expert (w2 GEMM) side when SBO is enabled."""

    # Number of CUs used by the w2 GEMM.
    num_sms: int
    # The same signal tensor that is stored in CombineOverlapArgs.signal.
    signal: torch.Tensor
    # The same event object as CombineOverlapArgs.wait_event; the expert
    # implementation records it right before launching the w2 GEMM.
    start_event: torch.cuda.Event
def dequant_fp8(expert_x_fp8: torch.Tensor,
expert_x_scales: torch.Tensor) -> torch.Tensor:
......@@ -59,6 +100,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.handle = None
self.num_dispatchers_ = num_dispatchers
# SBO: DeepEP LL overlap configuration
self.combine_overlap_args: CombineOverlapArgs | None = None
self.meta_overlap_args: dict[str, Any] | None = None
self.packed_recv_count: torch.Tensor | None = None
def num_dispatchers(self) -> int:
    """Return the dispatcher count stored on this instance at construction."""
    return self.num_dispatchers_
......@@ -118,6 +164,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
return x, x_scales
def supports_async(self) -> bool:
    """Report that the asynchronous entry points are available.

    Returns:
        Always ``True``: this class implements both ``prepare_async`` and
        ``finalize_async``.
    """
    return True
def prepare_async(
self,
......@@ -127,6 +181,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......@@ -166,6 +221,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return_recv_hook=True,
)
# We need to pass w2_gemm_overlap_args to the MoE implementation,
# so return it as an output parameter.
w2_gemm_overlap_args = self._create_sbo_args(local_num_experts, a1.device)
return (
hook,
lambda: self._receiver(
......@@ -175,6 +234,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1.dtype,
quant_config,
),
w2_gemm_overlap_args,
)
def _receiver(
......@@ -194,6 +254,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
)
# SBO: save the latest expert_num_tokens as packed_recv_count; the DeepEP LL combine needs it when SBO is used.
self.packed_recv_count = expert_num_tokens
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
......@@ -204,17 +267,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook, receiver = self.prepare_async(
hook, receiver, _ = self.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
num_experts,
local_num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
......@@ -241,18 +306,45 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_weights = torch.ones_like(topk_weights)
# TODO (varun) : Enable zero copy mode
_, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output,
topk_ids,
combine_topk_weights,
self.handle,
async_finish=False,
zero_copy=False,
return_recv_hook=do_recv_hook,
out=output,
)
ctx = nullcontext()
if self.combine_overlap_args is not None:
# For SBO, we need to wait for compute stream
# to have completed signal tensor allocation
self.combine_overlap_args.stream.wait_event(
self.combine_overlap_args.wait_event
)
# And we launch LL combine phase 1 on a separate stream
# for overlapping
ctx = torch.cuda.stream(self.combine_overlap_args.stream)
overlap_args_dict = dict(
overlap=self.combine_overlap_args.overlap,
packed_recv_count=self.packed_recv_count,
comp_signal=self.combine_overlap_args.signal,
block_m=self.meta_overlap_args["block_m"],
threshold=self.meta_overlap_args["threshold"],
num_sms=self.combine_overlap_args.num_sms,
)
else:
overlap_args_dict = {}
with ctx:
_, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output,
combine_topk_ids,
combine_topk_weights,
handle,
async_finish=False,
zero_copy=False,
return_recv_hook=do_recv_hook,
out=output,
**overlap_args_dict,
)
if self.combine_overlap_args is not None:
return recv_hook, lambda: self._sbo_wait_stream()
else:
return recv_hook, lambda: None
return recv_hook
def finalize_async(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
......@@ -283,3 +375,48 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_weights_and_reduce,
do_async=False,
)
def _create_sbo_args(
    self, local_num_experts: int, device: torch.device
) -> W2GemmOverlapArgs | None:
    """Build the per-call SBO (single batch overlapping) arguments.

    Resets ``self.combine_overlap_args`` and ``self.meta_overlap_args`` on
    every call so that settings from a previous call are never reused.

    Args:
        local_num_experts: Length of the signal tensor shared by the w2 GEMM
            and the combine (one entry per local expert).
        device: Device whose multiprocessor count is queried and on which
            the signal tensor is allocated.

    Returns:
        The arguments for the w2 GEMM side, or ``None`` when
        ``envs.VLLM_EP_USE_SBO`` is disabled. When SBO is enabled, the
        matching ``CombineOverlapArgs`` is stored on
        ``self.combine_overlap_args``.
    """
    self.combine_overlap_args = None
    if not envs.VLLM_EP_USE_SBO:
        self.meta_overlap_args = None
        return None

    # Start from an empty dict on every call to avoid reading stale
    # block_m / threshold values written by an earlier GEMM launch.
    self.meta_overlap_args = {}

    total_num_sms = torch.cuda.get_device_properties(
        device=device
    ).multi_processor_count
    # The combine kernel gets a fixed share of the CUs; the w2 GEMM gets
    # whatever is left.
    communicate_num_sms = 32
    compute_num_sms = total_num_sms - communicate_num_sms

    # One event and one signal tensor are shared by both argument objects.
    combine_wait_event = torch.cuda.Event()
    combine_signal = torch.zeros(
        local_num_experts, dtype=torch.uint32, device=device
    )

    # `overlap` is a required dataclass field (it has no default), so it
    # must be supplied at construction time; assigning it afterwards is too
    # late because the constructor call itself would raise TypeError.
    self.combine_overlap_args = CombineOverlapArgs(
        overlap=True,
        stream=alt_stream,
        wait_event=combine_wait_event,
        num_sms=communicate_num_sms,
        signal=combine_signal,
    )
    return W2GemmOverlapArgs(
        num_sms=compute_num_sms,
        signal=combine_signal,
        start_event=combine_wait_event,
    )
def _sbo_wait_stream(self) -> None:
    """Make the current stream wait for the SBO combine stream.

    With SBO enabled, low-latency combine phase 2 is still launched on the
    main compute stream, so it has to wait for phase 1, which was launched
    on the combine stream. Does nothing when no overlap arguments are set.
    """
    overlap_args = self.combine_overlap_args
    if overlap_args is None:
        return
    torch.cuda.current_stream().wait_stream(overlap_args.stream)
......@@ -502,6 +502,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......
......@@ -15,6 +15,9 @@ import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import cdiv
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
#
# This file defines a set of base classes used to make MoE kernels more modular.
......@@ -163,6 +166,7 @@ class FusedMoEPrepareAndFinalize(ABC):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......@@ -192,6 +196,14 @@ class FusedMoEPrepareAndFinalize(ABC):
- Optional dispatched expert topk weight
"""
raise NotImplementedError
def supports_async(self) -> bool:
    """Report whether the asynchronous entry points are available.

    Returns:
        ``False`` by default; a subclass that implements ``prepare_async``
        and ``finalize_async`` overrides this to return ``True``.
    """
    return False
@abstractmethod
def finalize(
......@@ -610,6 +622,184 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{prepare_finalize.activation_format} == "
f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}")
if self.shared_experts is not None:
self.alt_stream = alt_stream()
self.alt_event = torch.cuda.Event()
def _prepare(
    self,
    hidden_states: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int,
    local_num_experts: int,
    expert_map: torch.Tensor | None,
    apply_router_weight_on_input: bool,
) -> tuple[
    torch.Tensor,
    torch.Tensor | None,
    ExpertTokensMetadata | None,
    torch.Tensor,
    torch.Tensor,
    object | None,
]:
    """Run the prepare step, using the async path when it is supported.

    Wraps ``self.prepare_finalize.prepare`` / ``prepare_async``. DBO
    handling is not enabled here yet (see the TODOs below).

    Returns:
        A 6-tuple ``(a1q, a1q_scale, expert_tokens_meta, topk_ids,
        topk_weights, w2_gemm_overlap_args)``. ``topk_ids`` and
        ``topk_weights`` are the inputs unless the prepare step returned
        replacements. ``w2_gemm_overlap_args`` is ``None`` unless the async
        prepare returned a third element.
    """
    w2_gemm_overlap_args = None

    if not self.prepare_finalize.supports_async():
        # We shouldn't be running an a2a kernel that doesn't
        # support async prepare/finalize
        # TODO(lucas): enable `assert not dbo_enabled()` in follow-up
        (
            a1q,
            a1q_scale,
            expert_tokens_meta,
            _expert_topk_ids,
            _expert_topk_weights,
        ) = self.prepare_finalize.prepare(
            hidden_states,
            a1_scale,
            a2_scale,
            topk_weights,
            topk_ids,
            global_num_experts,
            local_num_experts,
            expert_map,
            apply_router_weight_on_input,
            self.fused_experts.quant_config,
        )
    else:
        # Overlap shared expert compute with all2all dispatch.
        prepare_ret = self.prepare_finalize.prepare_async(
            hidden_states,
            a1_scale,
            a2_scale,
            topk_weights,
            topk_ids,
            global_num_experts,
            local_num_experts,
            expert_map,
            apply_router_weight_on_input,
            self.fused_experts.quant_config,
        )

        # TODO(lucas): refactor this in the alternative schedules followup
        # prepare_async may return a bare receiver, a (hook, receiver) pair,
        # or a (hook, receiver, w2_gemm_overlap_args) triple. Unpacking three
        # names unconditionally raises ValueError for the first two shapes,
        # so the optional third element is handled separately.
        if isinstance(prepare_ret, tuple):
            hook, receiver, *extra = prepare_ret
            if extra:
                w2_gemm_overlap_args = extra[0]
        else:
            hook, receiver = None, prepare_ret

        if hook is not None:
            # TODO: when DBO is enabled, register the hook with the ubatch
            # context (dbo_register_recv_hook + dbo_yield) instead of
            # calling it here.
            hook()

        (
            a1q,
            a1q_scale,
            expert_tokens_meta,
            _expert_topk_ids,
            _expert_topk_weights,
        ) = receiver()

    # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
    if _expert_topk_ids is not None:
        topk_ids = _expert_topk_ids
    if _expert_topk_weights is not None:
        topk_weights = _expert_topk_weights

    return (
        a1q,
        a1q_scale,
        expert_tokens_meta,
        topk_ids,
        topk_weights,
        w2_gemm_overlap_args,
    )
def _finalize(
    self,
    output: torch.Tensor,
    fused_out: torch.Tensor,
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    apply_router_weight_on_input: bool,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Run the finalize step and, if configured, the shared experts.

    Wraps self.prepare_finalize.finalize / finalize_async. On the async
    path the finalize work is issued on self.alt_stream while the shared
    experts are invoked from the calling stream. DBO handling is currently
    commented out.

    Returns:
        ``output`` when there are no shared experts, otherwise the pair
        ``(shared_output, output)``.
    """
    shared_output: torch.Tensor | None = None

    if not self.prepare_finalize.supports_async():
        # Synchronous path: finalize first, then run the shared experts.
        #assert not dbo_enabled()

        self.prepare_finalize.finalize(
            output,
            fused_out,
            topk_weights,
            topk_ids,
            apply_router_weight_on_input,
        )
        if self.shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
    else:
        # NOTE(review): in the visible constructor code, alt_stream and
        # alt_event are only created when shared_experts is not None, yet
        # they are used unconditionally on this path - confirm they always
        # exist when supports_async() is True.
        # Record the current point on the calling stream; the alternate
        # stream waits on this event before issuing the finalize work.
        self.alt_event.record()
        if self.shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        current_stream = torch.cuda.current_stream()
        # NOTE(review): the extent of this `with` block is reconstructed;
        # it is assumed to end after the second alt_event.record() so that
        # the event is recorded on alt_stream - TODO confirm.
        with torch.cuda.stream(self.alt_stream):
            self.alt_stream.wait_event(self.alt_event)
            finalize_ret = self.prepare_finalize.finalize_async(
                output,
                fused_out,
                topk_weights,
                topk_ids,
                apply_router_weight_on_input,
            )
            # TODO(lucas): refactor this in the alternative schedules followup
            # currently unpack if we have hook + receiver pair or just
            # receiver (see finalize_async docstring)
            hook, receiver = (
                finalize_ret
                if isinstance(finalize_ret, tuple)
                else (None, finalize_ret)
            )
            if hook is not None:
                # if dbo_enabled():
                #     # If DBO is being used, register the hook with the ubatch
                #     # context and call it in dbo_maybe_run_recv_hook instead of
                #     # passing it to the receiver.
                #     dbo_register_recv_hook(hook)
                #     dbo_yield()
                # else:
                hook()
            receiver()
            self.alt_event.record()
        # The calling stream must not continue until the finalize work has
        # completed.
        current_stream.wait_event(self.alt_event)
    if self.shared_experts is None:
        return output
    else:
        assert shared_output is not None
        return shared_output, output
def forward(
self,
......@@ -674,13 +864,14 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts = local_num_experts
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
_expert_topk_weights, w2_gemm_overlap_args) = self._prepare(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
local_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
......@@ -739,26 +930,48 @@ class FusedMoEModularKernel(torch.nn.Module):
if num_chunks == 1:
fused_out = _resize_cache(workspace13, fused_out_shape)
self.fused_experts.apply(
fused_out,
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
if isinstance(self.fused_experts, BatchedDeepGemmExperts):
self.fused_experts.apply(
fused_out,
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
w2_gemm_overlap_args=w2_gemm_overlap_args,
meta_overlap_args=self.prepare_finalize.meta_overlap_args,
)
else:
self.fused_experts.apply(
fused_out,
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
else:
# The leading output dimension may not be equal to M, so
# we compute output indices separately.
......@@ -786,28 +999,50 @@ class FusedMoEModularKernel(torch.nn.Module):
curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx,
end_chunk_idx)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
if isinstance(self.fused_experts, BatchedDeepGemmExperts):
self.fused_experts.apply(
fused_out[begin_out_idx:end_out_idx],
curr_a1q,
w1,
w2,
curr_topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=curr_a1q_scale,
a2_scale=curr_a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
w2_gemm_overlap_args=w2_gemm_overlap_args,
meta_overlap_args=self.prepare_finalize.meta_overlap_args,
)
else:
self.fused_experts.apply(
fused_out[begin_out_idx:end_out_idx],
curr_a1q,
w1,
w2,
curr_topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=curr_a1q_scale,
a2_scale=curr_a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
self.fused_experts.apply(
fused_out[begin_out_idx:end_out_idx],
curr_a1q,
w1,
w2,
curr_topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=curr_a1q_scale,
a2_scale=curr_a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
self.prepare_finalize.finalize(output, fused_out, topk_weights,
self._finalize(output, fused_out, hidden_states, topk_weights,
topk_ids, apply_router_weight_on_input)
return output
......
......@@ -91,6 +91,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......
......@@ -35,6 +35,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try:
from ._version import __version__, __version_tuple__
__version__ = "0.9.2"
__version_tuple__ = (0, 9, 2)
__hcu_version__ = f'0.9.2+das.opt6.dtk25044'
from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
import warnings
warnings.warn(f"Failed to read commit hash:\n{e}",
warnings.warn(f"Failed to read commit hash:\n + str(e)",
RuntimeWarning,
stacklevel=2)
__version__ = "dev"
__version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str):
"""Check whether a given version matches the previous minor version.
'''Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
......@@ -23,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
'''
# Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0):
return True
# Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0
# assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version():
    """Return the previous minor version as a string, for use in tests."""
    # In a dev tree this yields "0.-1", which works fine for testing.
    minor = __version_tuple__[1]
    assert isinstance(minor, int)
    major = __version_tuple__[0]
    return f"{major}.{minor - 1}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment