Unverified commit a11bc12d, authored by Harry Mellor and committed by GitHub
Browse files

Fix `test_moe.py` for Transformers v5 (#33413)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 58cb55e4
@@ -680,13 +680,21 @@ def test_mixtral_moe(
 
     # Load the weights
     vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
-    for i in range(config.num_local_experts):
-        weights = (
-            hf_moe.experts[i].w1.weight.data,
-            hf_moe.experts[i].w3.weight.data,
-        )
-        vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
-        vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
+    if isinstance(hf_moe.experts, torch.nn.ModuleList):
+        # Transformers v4
+        for i in range(config.num_local_experts):
+            weights = (
+                hf_moe.experts[i].w1.weight.data,
+                hf_moe.experts[i].w3.weight.data,
+            )
+            vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
+            vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
+    else:
+        # Transformers v5
+        vllm_moe.experts.w13_weight.data[:] = hf_moe.experts.gate_up_proj.data
+        vllm_moe.experts.w2_weight.data[:] = hf_moe.experts.down_proj.data
+        # TODO: remove this line after https://github.com/huggingface/transformers/pull/43622
+        hf_moe.experts.config._experts_implementation = "eager"
 
     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
     hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
@@ -718,7 +726,10 @@ def test_mixtral_moe(
     get_forward_context().all_moe_layers = None
 
     # Run forward passes for both MoE blocks
-    hf_states, _ = hf_moe.forward(hf_inputs)
+    hf_states = hf_moe.forward(hf_inputs)
+    if isinstance(hf_states, tuple):
+        # Transformers v4
+        hf_states = hf_states[0]
     vllm_states = vllm_moe.forward(vllm_inputs)
 
     mixtral_moe_tol = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment