Commit 8aa30111 authored by zhuwenwen

add support for llama_nn

parent 055b67ee
@@ -13,8 +13,10 @@ from vllm.model_executor.parallel_utils.utils import (
     divide, split_tensor_along_last_dim)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.logger import init_logger
+import os

 logger = init_logger(__name__)
+USE_LLAMA_NN = int(os.environ.get('LLAMA_NN', '0')) == 1


 def adjust_marlin_shard(param, shard_size, shard_offset):
@@ -76,7 +78,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
             if bias:
                 return F.linear(x, weight) + bias
             return F.linear(x, weight)
-        return F.linear(x, weight, bias)
+        if USE_LLAMA_NN:
+            weight = weight.reshape(weight.shape[1], -1)
+            return torch.matmul(x, weight)
+        else:
+            return F.linear(x, weight, bias)


 class ReplicatedLinear(torch.nn.Module):
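
The forward-path change above replaces `F.linear(x, weight, bias)` with a plain `torch.matmul` against a weight viewed as `[in_features, out_features]`, which only works because the weight data is stored transposed at load time (the weight_loader hunks below). A minimal standalone sketch of that equivalence, with hypothetical shapes and no bias:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 16)        # [batch, in_features]
weight = torch.randn(32, 16)  # [out_features, in_features], the layout F.linear expects

# Default path: F.linear computes x @ weight.T (+ bias).
y_ref = F.linear(x, weight)

# LLAMA_NN-style path: keep the weight data pre-transposed so the forward pass is a
# plain matmul. Here the transpose is explicit; the commit instead stores transposed
# data in the original buffer and recovers this view via reshape(weight.shape[1], -1).
weight_t = weight.t().contiguous()  # [in_features, out_features]
y_nn = torch.matmul(x, weight_t)

assert torch.allclose(y_ref, y_nn, atol=1e-5)
```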
@@ -206,6 +212,9 @@ class ColumnParallelLinear(torch.nn.Module):
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
         assert param_data.shape == loaded_weight.shape
+        if USE_LLAMA_NN:
+            loaded_weight = loaded_weight.transpose(0, 1)
+            loaded_weight = loaded_weight.reshape(param_data.shape[0], -1)
         param_data.copy_(loaded_weight)

     def forward(self, input_):
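
In ColumnParallelLinear (and RowParallelLinear in the last hunk), the loader keeps the parameter's original `[out_features, in_features]` shape but fills it with transposed data, flattening the transposed shard back into the existing buffer. A small sketch of that store-transposed, reshape-back pattern, using hypothetical sizes:

```python
import torch

out_features, in_features = 32, 16
param_data = torch.empty(out_features, in_features)      # parameter keeps its [out, in] shape
loaded_weight = torch.randn(out_features, in_features)   # checkpoint shard in [out, in] layout

# Store the transposed data inside the same [out, in] buffer.
transposed = loaded_weight.transpose(0, 1)                      # [in, out] view
param_data.copy_(transposed.reshape(param_data.shape[0], -1))   # flattened back to [out, in]

# The forward-path reshape(weight.shape[1], -1) then recovers the [in, out] matrix.
recovered = param_data.reshape(param_data.shape[1], -1)
assert torch.equal(recovered, loaded_weight.t())
```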
@@ -311,8 +320,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)

-            param_data = param_data.narrow(output_dim, shard_offset,
-                                           shard_size)
+            if USE_LLAMA_NN:
+                param_data_ = param_data.narrow(output_dim, shard_offset,
+                                                shard_size)
+            else:
+                param_data = param_data.narrow(output_dim, shard_offset,
+                                               shard_size)
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
@@ -323,8 +336,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     "Loading a weight without `output_dim` attribute in "
                     "MergedColumnParallelLinear, assume the weight is "
                     "the same for all partitions.")
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
+
+        if USE_LLAMA_NN:
+            assert param_data_.shape == loaded_weight.shape
+            param_data_.copy_(loaded_weight)
+            if loaded_shard_id == 1:
+                param_data = param_data.transpose(0, 1)
+                param.data = param_data.reshape(param_data.shape[1], -1)
+        else:
+            assert param_data.shape == loaded_weight.shape
+            param_data.copy_(loaded_weight)


 class QKVParallelLinear(ColumnParallelLinear):
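
MergedColumnParallelLinear (and QKVParallelLinear below) cannot transpose shard by shard, because the gate/up (or q/k/v) weights arrive separately and each one has to land in its own slice of the merged matrix. The commit therefore copies every shard into a narrowed view and only transposes and rebinds `param.data` once the final shard has been loaded (`loaded_shard_id == 1` here, `"v"` in the QKV case). A standalone sketch of that deferred transpose with two hypothetical shards:

```python
import torch

shard_size, in_features = 8, 12
param = torch.nn.Parameter(torch.empty(2 * shard_size, in_features),
                           requires_grad=False)

shards = [torch.randn(shard_size, in_features),   # e.g. a gate_proj shard
          torch.randn(shard_size, in_features)]   # e.g. an up_proj shard

for shard_id, loaded_weight in enumerate(shards):
    # Copy each shard into its slice of the merged [2*shard, in] matrix.
    param.data.narrow(0, shard_id * shard_size, shard_size).copy_(loaded_weight)
    if shard_id == len(shards) - 1:
        # Only after the last shard: store the transposed merged matrix, flattened
        # into a buffer whose layout matches the forward-time reshape.
        merged_t = param.data.transpose(0, 1)                  # [in, 2*shard]
        param.data = merged_t.reshape(merged_t.shape[1], -1)   # rebind as [2*shard, in]

# The forward-time reshape now yields the transposed merged weight.
assert torch.equal(param.data.reshape(param.data.shape[1], -1),
                   torch.cat(shards, dim=0).t())
```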
@@ -446,7 +467,11 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)

-            param_data = param_data.narrow(output_dim, shard_offset,
-                                           shard_size)
+            if USE_LLAMA_NN:
+                param_data_ = param_data.narrow(output_dim, shard_offset,
+                                                shard_size)
+            else:
+                param_data = param_data.narrow(output_dim, shard_offset,
+                                               shard_size)
             if loaded_shard_id == "q":
                 shard_id = tp_rank
@@ -462,8 +487,16 @@ class QKVParallelLinear(ColumnParallelLinear):
                     "Loading a weight without `output_dim` attribute in "
                     "QKVParallelLinear, assume the weight is the same "
                     "for all partitions.")
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
+
+        if USE_LLAMA_NN:
+            assert param_data_.shape == loaded_weight.shape
+            param_data_.copy_(loaded_weight)
+            if loaded_shard_id == "v":
+                param_data = param_data.transpose(0, 1)
+                param.data = param_data.reshape(param_data.shape[1], -1)
+        else:
+            assert param_data.shape == loaded_weight.shape
+            param_data.copy_(loaded_weight)


 class RowParallelLinear(torch.nn.Module):
@@ -552,6 +585,9 @@ class RowParallelLinear(torch.nn.Module):
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
         assert param_data.shape == loaded_weight.shape
+        if USE_LLAMA_NN:
+            loaded_weight = loaded_weight.transpose(0, 1)
+            loaded_weight = loaded_weight.reshape(param_data.shape[0], -1)
         param_data.copy_(loaded_weight)

     def forward(self, input_):
...
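
Everything above is gated by the `LLAMA_NN` environment variable, which is read once when the module is imported. If this sketch of the intended usage is right, the flag has to be set before vLLM is imported (model name and prompt are placeholders):

```python
import os
os.environ["LLAMA_NN"] = "1"   # must be set before importing vllm; the flag is
                               # read at module import time

from vllm import LLM

llm = LLM(model="meta-llama/Llama-2-7b-hf")   # placeholder model
outputs = llm.generate("Hello, my name is")   # placeholder prompt
```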