Commit 24962bed authored by zhuwenwen
Browse files

fix MergedColumnParallelLinear weight_loader

parent 7f1d5aff
......@@ -930,7 +930,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
if not envs.VLLM_USE_NN or self.is_quantization:
if not envs.VLLM_USE_NN or self.is_quantization or (envs.VLLM_USE_NN and param_data.dim()==1):
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
else:
param_data = param_data.narrow(int(not(output_dim)), shard_offset, shard_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment