Commit 0d27f0c7 authored by zhuwenwen's avatar zhuwenwen
Browse files

add linear bias

parent f26ecef8
...@@ -79,9 +79,13 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -79,9 +79,13 @@ class UnquantizedLinearMethod(LinearMethodBase):
if bias: if bias:
return F.linear(x, weight) + bias return F.linear(x, weight) + bias
return F.linear(x, weight) return F.linear(x, weight)
if self.use_llama_nn: if self.use_llama_nn:
weight = weight.reshape(weight.shape[1], -1) weight = weight.reshape(weight.shape[1], -1)
return torch.matmul(x, weight) if bias is not None:
return torch.matmul(x, weight) + bias
else:
return torch.matmul(x, weight)
else: else:
return F.linear(x, weight, bias) return F.linear(x, weight, bias)
...@@ -343,7 +347,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -343,7 +347,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if self.use_llama_nn: if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight) param_data_.copy_(loaded_weight)
if loaded_shard_id == 1: if loaded_shard_id == 1 and len(param_data.shape) == 2:
param_data = param_data.transpose(0, 1) param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1) param.data = param_data.reshape(param_data.shape[1], -1)
else: else:
...@@ -477,7 +481,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -477,7 +481,7 @@ class QKVParallelLinear(ColumnParallelLinear):
else: else:
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
if loaded_shard_id == "q": if loaded_shard_id == "q" and len(param_data.shape) == 2:
shard_id = tp_rank shard_id = tp_rank
else: else:
shard_id = tp_rank // self.num_kv_head_replicas shard_id = tp_rank // self.num_kv_head_replicas
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment