Commit 89683b9e authored by zhuwenwen

Add LLAMA_NN support: a transposed-weight matmul path for Llama models

parent 3e147e19
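
For readers of this diff: the hunks below gate a transposed-weight code path behind a `LLAMA_NN` environment variable. The weight loaders store each checkpoint tensor transposed inside the existing `(out_features, in_features)` buffer, and `UnquantizedLinearMethod` then reinterprets that buffer with a `reshape` and uses a plain `torch.matmul` instead of `F.linear`. The snippet below is a minimal sketch of that equivalence, not the committed code; all shapes and tensor names are illustrative.

```python
import torch
import torch.nn.functional as F

in_features, out_features = 4, 8
x = torch.randn(2, in_features)
w = torch.randn(out_features, in_features)   # checkpoint layout: (out, in)

# Loader path under LLAMA_NN=1: pack w^T's data into the (out, in) buffer.
param = w.transpose(0, 1).reshape(out_features, -1)

# Forward path under LLAMA_NN=1: view the buffer back as (in, out) and matmul.
w_nn = param.reshape(param.shape[1], -1)     # (in, out), just a view
assert torch.allclose(torch.matmul(x, w_nn), F.linear(x, w))
```

Under this layout the transpose happens once at load time rather than on every forward call, which appears to be the point of the change.
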
......@@ -13,6 +13,8 @@ from vllm.model_executor.parallel_utils.utils import (
divide, split_tensor_along_last_dim)
from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger
import os
logger = init_logger(__name__)
......@@ -55,6 +57,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def create_weights(self, input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
......@@ -76,6 +79,14 @@ class UnquantizedLinearMethod(LinearMethodBase):
if bias:
return F.linear(x, weight) + bias
return F.linear(x, weight)
if self.use_llama_nn:
weight = weight.reshape(weight.shape[1], -1)
if bias is not None:
return torch.matmul(x, weight) + bias
else:
return torch.matmul(x, weight)
else:
return F.linear(x, weight, bias)
......@@ -195,6 +206,7 @@ class ColumnParallelLinear(torch.nn.Module):
})
else:
self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
......@@ -206,6 +218,9 @@ class ColumnParallelLinear(torch.nn.Module):
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(param_data.shape[0], -1)
param_data.copy_(loaded_weight)
def forward(self, input_):
......@@ -259,6 +274,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, linear_method)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self,
param: Parameter,
......@@ -311,6 +327,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
if self.use_llama_nn:
param_data_ = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
......@@ -323,6 +343,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
if loaded_shard_id == 1 and len(param_data.shape) == 2:
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
......@@ -383,6 +411,7 @@ class QKVParallelLinear(ColumnParallelLinear):
2 * self.num_kv_heads) * tp_size * self.head_size
super().__init__(input_size, output_size, bias, False, skip_bias_add,
params_dtype, linear_method)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self,
param: Parameter,
......@@ -446,6 +475,10 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
if self.use_llama_nn:
param_data_ = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":
......@@ -462,6 +495,14 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
if loaded_shard_id == "v" and len(param_data.shape) == 2:
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
......@@ -541,6 +582,7 @@ class RowParallelLinear(torch.nn.Module):
})
else:
self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
......@@ -552,6 +594,9 @@ class RowParallelLinear(torch.nn.Module):
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(param_data.shape[0], -1)
param_data.copy_(loaded_weight)
def forward(self, input_):
......
......@@ -9,6 +9,7 @@ from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
import os
@contextlib.contextmanager
......@@ -22,6 +23,8 @@ def _set_default_torch_dtype(dtype: torch.dtype):
def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
architectures = getattr(model_config.hf_config, "architectures", [])
if architectures == ['LlamaForCausalLM']:
os.environ['LLAMA_NN'] = '1'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
......@@ -62,6 +65,9 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
f"{supported_dtypes}")
linear_method = quant_config.get_linear_method()
if linear_method is not None:
os.environ['LLAMA_NN'] = '0'
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
......
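
A brief summary of the gating added in the model loader hunks above, written as a hypothetical helper for illustration only (the committed code sets the variable inline): the flag is enabled only for `LlamaForCausalLM` and is forced back off whenever a quantization `linear_method` is in use, and because every linear layer reads `LLAMA_NN` in its constructor, the variable must be set before the model is instantiated.

```python
import os

def _maybe_enable_llama_nn(architectures, linear_method) -> None:
    # Hypothetical helper mirroring the inline logic in this commit.
    if architectures == ['LlamaForCausalLM']:
        os.environ['LLAMA_NN'] = '1'   # enable the transposed-weight path
    if linear_method is not None:
        os.environ['LLAMA_NN'] = '0'   # quantized models keep F.linear
```
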