"docs/vscode:/vscode.git/clone" did not exist on "2c1bd848a668787082c0a9364d96db13a9201baa"
Unverified Commit e7b68f4d authored by Xin Yang's avatar Xin Yang Committed by GitHub
Browse files

[Bugfix] Fix Triton FusedMoE LoRA (#30585)


Signed-off-by: default avatarXin Yang <xyangx@amazon.com>
parent 1a19e9cd
......@@ -167,6 +167,7 @@ depthwise_seperable_CNN = "depthwise_seperable_CNN"
[tool.typos.default.extend-words]
iy = "iy"
tendencias = "tendencias"
indx = "indx"
# intel cpu features
tme = "tme"
dout = "dout"
......
......@@ -69,41 +69,54 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
def test_gpt_oss_lora(gptoss20b_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=8,
max_num_seqs=2,
max_num_batched_tokens=2048,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
@pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
def test_gpt_oss_lora(
monkeypatch: pytest.MonkeyPatch, gptoss20b_lora_files, mxfp4_use_marlin
):
with monkeypatch.context() as m:
m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=8,
max_num_seqs=2,
max_num_batched_tokens=2048,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("fully_sharded_loras", [False, True])
def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=2,
max_num_seqs=2,
max_num_batched_tokens=2048,
tensor_parallel_size=2,
gpu_memory_utilization=0.8,
fully_sharded_loras=fully_sharded_loras,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
@pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
def test_gpt_oss_lora_tp2(
monkeypatch: pytest.MonkeyPatch,
gptoss20b_lora_files,
fully_sharded_loras,
mxfp4_use_marlin,
):
with monkeypatch.context() as m:
m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=2,
max_num_seqs=2,
max_num_batched_tokens=2048,
tensor_parallel_size=2,
gpu_memory_utilization=0.8,
fully_sharded_loras=fully_sharded_loras,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
......@@ -502,16 +502,18 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
)
self.activation(
activation, intermediate_cache2, intermediate_cache1.view(-1, N)
activation,
intermediate_cache2,
intermediate_cache1.view(-1, N)[gather_indx.dst_indx],
)
# matmul_ogs grouped reduction fuse sum across multiple experts:
# y[dst_ind // n_expts_act, :] += x[src_ind, :]
# y[dst_indx // n_expts_act, :] += x
# Need to set n_expts_act to 1 to unfuse moe_sum
routing_data.n_expts_act = 1
matmul_ogs(
intermediate_cache2,
intermediate_cache2[gather_indx.src_indx],
w2,
self.quant_config.w2_bias,
routing_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment