Unverified commit e7b68f4d, authored by Xin Yang and committed by GitHub
Browse files

[Bugfix] Fix Triton FusedMoE LoRA (#30585)


Signed-off-by: Xin Yang <xyangx@amazon.com>
parent 1a19e9cd
...@@ -167,6 +167,7 @@ depthwise_seperable_CNN = "depthwise_seperable_CNN" ...@@ -167,6 +167,7 @@ depthwise_seperable_CNN = "depthwise_seperable_CNN"
[tool.typos.default.extend-words] [tool.typos.default.extend-words]
iy = "iy" iy = "iy"
tendencias = "tendencias" tendencias = "tendencias"
indx = "indx"
# intel cpu features # intel cpu features
tme = "tme" tme = "tme"
dout = "dout" dout = "dout"
......
...@@ -69,41 +69,54 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None: ...@@ -69,41 +69,54 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i]) assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
def test_gpt_oss_lora(gptoss20b_lora_files): @pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
llm = vllm.LLM( def test_gpt_oss_lora(
MODEL_PATH, monkeypatch: pytest.MonkeyPatch, gptoss20b_lora_files, mxfp4_use_marlin
max_model_len=1024, ):
enable_lora=True, with monkeypatch.context() as m:
max_loras=4, m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
max_lora_rank=8, llm = vllm.LLM(
max_num_seqs=2, MODEL_PATH,
max_num_batched_tokens=2048, max_model_len=1024,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM enable_lora=True,
cudagraph_specialize_lora=False, max_loras=4,
), max_lora_rank=8,
) max_num_seqs=2,
max_num_batched_tokens=2048,
generate_and_test(llm, gptoss20b_lora_files, lora_id=1) compilation_config=vllm.config.CompilationConfig( # Avoid OOM
generate_and_test(llm, gptoss20b_lora_files, lora_id=2) cudagraph_specialize_lora=False,
),
)
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("fully_sharded_loras", [False, True]) @pytest.mark.parametrize("fully_sharded_loras", [False, True])
def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras): @pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
llm = vllm.LLM( def test_gpt_oss_lora_tp2(
MODEL_PATH, monkeypatch: pytest.MonkeyPatch,
max_model_len=1024, gptoss20b_lora_files,
enable_lora=True, fully_sharded_loras,
max_loras=2, mxfp4_use_marlin,
max_num_seqs=2, ):
max_num_batched_tokens=2048, with monkeypatch.context() as m:
tensor_parallel_size=2, m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
gpu_memory_utilization=0.8, llm = vllm.LLM(
fully_sharded_loras=fully_sharded_loras, MODEL_PATH,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM max_model_len=1024,
cudagraph_specialize_lora=False, enable_lora=True,
), max_loras=2,
) max_num_seqs=2,
max_num_batched_tokens=2048,
generate_and_test(llm, gptoss20b_lora_files, lora_id=1) tensor_parallel_size=2,
generate_and_test(llm, gptoss20b_lora_files, lora_id=2) gpu_memory_utilization=0.8,
fully_sharded_loras=fully_sharded_loras,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
)
generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
...@@ -502,16 +502,18 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): ...@@ -502,16 +502,18 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
) )
self.activation( self.activation(
activation, intermediate_cache2, intermediate_cache1.view(-1, N) activation,
intermediate_cache2,
intermediate_cache1.view(-1, N)[gather_indx.dst_indx],
) )
# matmul_ogs grouped reduction fuse sum across multiple experts: # matmul_ogs grouped reduction fuse sum across multiple experts:
# y[dst_ind // n_expts_act, :] += x[src_ind, :] # y[dst_indx // n_expts_act, :] += x
# Need to set n_expts_act to 1 to unfuse moe_sum # Need to set n_expts_act to 1 to unfuse moe_sum
routing_data.n_expts_act = 1 routing_data.n_expts_act = 1
matmul_ogs( matmul_ogs(
intermediate_cache2, intermediate_cache2[gather_indx.src_indx],
w2, w2,
self.quant_config.w2_bias, self.quant_config.w2_bias,
routing_data, routing_data,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment