Unverified Commit 9465b668 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Allow running with vllm==0.4.3 (#561)

parent 05471f21
import json import json
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel from pydantic import BaseModel
try:
from outlines.caching import cache as disk_cache
from outlines.fsm.guide import RegexGuide
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
except ImportError as e:
print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
raise
try: try:
from outlines.fsm.json_schema import build_regex_from_object from outlines.fsm.json_schema import build_regex_from_object
except ImportError: except ImportError:
......
...@@ -512,8 +512,13 @@ def fused_moe( ...@@ -512,8 +512,13 @@ def fused_moe(
# Check constraints. # Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
if hasattr(ops, "topk_softmax"):
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize) renormalize)
else:
topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output, topk,
renormalize)
return fused_experts(hidden_states, return fused_experts(hidden_states,
w1, w1,
w2, w2,
...@@ -526,3 +531,33 @@ def fused_moe( ...@@ -526,3 +531,33 @@ def fused_moe(
w2_scale=w2_scale, w2_scale=w2_scale,
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale) a2_scale=a2_scale)
def fused_topk_v0_4_3(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Select top-k experts per token using vllm 0.4.3's C++ topk_softmax kernel.

    Fallback path for vllm==0.4.3, where the kernel lives in ``vllm._moe_C``
    rather than being exposed via ``ops.topk_softmax``.

    Args:
        hidden_states: token activations; only its leading (num_tokens) dim
            and device are read here.
        gating_output: router logits, cast to float32 before the kernel call.
        topk: number of experts to pick per token.
        renormalize: if True, rescale the selected weights to sum to 1.

    Returns:
        (topk_weights, topk_ids): float32 weights and int32 expert indices,
        each of shape (num_tokens, topk).
    """
    # Local import: only reachable when running against vllm==0.4.3.
    import vllm._moe_C as moe_kernels

    num_tokens = hidden_states.shape[0]
    device = hidden_states.device

    # Output buffers the kernel writes into.
    topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device=device)
    topk_ids = torch.empty(num_tokens, topk, dtype=torch.int32, device=device)
    expert_indices = torch.empty(num_tokens, topk, dtype=torch.int32, device=device)

    moe_kernels.topk_softmax(
        topk_weights,
        topk_ids,
        expert_indices,
        gating_output.float(),  # TODO(woosuk): Optimize this.
    )
    del expert_indices  # Not used. Will be used in the future.

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment