"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "8fcaaf6a165e661f63fc51be906bc05b0767332f"
Commit f112086f authored by zhuwenwen's avatar zhuwenwen
Browse files

update test_moe.py

parent a75021db
...@@ -171,8 +171,8 @@ def test_mixtral_moe(dtype: torch.dtype): ...@@ -171,8 +171,8 @@ def test_mixtral_moe(dtype: torch.dtype):
for i in range(config.num_local_experts): for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data, weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data) hf_moe.experts[i].w3.weight.data)
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) vllm_moe.experts.w13_weight[i][:] = (torch.cat(weights, dim=0)).T
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data vllm_moe.experts.w2_weight[i][:] = (hf_moe.experts[i].w2.weight.data).T
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim] # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment