Commit f26ecef8 authored by zhuwenwen

add llama_nn support

parent 96012705
@@ -15,8 +15,8 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.logger import init_logger
 import os

 logger = init_logger(__name__)
-USE_LLAMA_NN = int(os.environ.get('LLAMA_NN', '0')) == 1


 def adjust_marlin_shard(param, shard_size, shard_offset):
@@ -57,6 +57,7 @@ class UnquantizedLinearMethod(LinearMethodBase):

     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add
+        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

     def create_weights(self, input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
@@ -78,7 +79,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
             if bias:
                 return F.linear(x, weight) + bias
             return F.linear(x, weight)
-        if USE_LLAMA_NN:
+        if self.use_llama_nn:
            weight = weight.reshape(weight.shape[1], -1)
            return torch.matmul(x, weight)
         else:
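A minimal sketch of why this matmul path matches the stock `F.linear` path, assuming (as the weight_loader changes below suggest) that the parameter keeps its registered (out_features, in_features) shape while physically holding the transposed data:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 16)       # (batch, in_features)
w = torch.randn(32, 16)      # vLLM's native weight layout: (out_features, in_features)

# Native path: y = x @ w.T
y_ref = F.linear(x, w)

# LLAMA_NN path: the buffer keeps the registered (out, in) shape but holds the
# transposed data, so apply_weights just views it back to (in, out) and runs a
# plain matmul with no transpose op.
w_buf = w.t().contiguous().reshape(w.shape)        # transposed data, original shape
y_nn = torch.matmul(x, w_buf.reshape(w_buf.shape[1], -1))

assert torch.allclose(y_ref, y_nn, atol=1e-5)
```

Storing the transpose up front trades a one-time re-layout at load time for a transpose-free matmul on every forward pass, which is presumably what the llama_nn backend wants.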
@@ -201,6 +202,7 @@ class ColumnParallelLinear(torch.nn.Module):
             })
         else:
             self.register_parameter("bias", None)
+        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
@@ -212,7 +214,7 @@ class ColumnParallelLinear(torch.nn.Module):
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
         assert param_data.shape == loaded_weight.shape
-        if USE_LLAMA_NN:
+        if self.use_llama_nn:
             loaded_weight = loaded_weight.transpose(0, 1)
             loaded_weight = loaded_weight.reshape(param_data.shape[0], -1)
         param_data.copy_(loaded_weight)
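The loader is the other half of that contract. A small sketch of what the two added lines do to the tensor-parallel shard, with hypothetical sizes (RowParallelLinear further down applies the same two lines to its input-dim shard):

```python
import torch

# Checkpoint shard for this TP rank, HF layout: (out_shard, in)
loaded = torch.randn(8, 16)
# Parameter buffer, registered as (out_shard, in)
param = torch.empty(8, 16)

# LLAMA_NN path: transpose to (in, out_shard), then pour the transposed data
# back into the registered shape so the shape assert above still passes.
param.copy_(loaded.transpose(0, 1).reshape(param.shape[0], -1))

# The bytes are now in (in, out) order even though the shape reads (out, in):
assert torch.equal(param.reshape(16, 8), loaded.t())
```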
@@ -268,6 +270,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size, sum(output_sizes), bias, gather_output,
                          skip_bias_add, params_dtype, linear_method)
+        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

     def weight_loader(self,
                       param: Parameter,
......@@ -320,7 +323,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
if USE_LLAMA_NN:
if self.use_llama_nn:
param_data_ = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
@@ -337,7 +340,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     "MergedColumnParallelLinear, assume the weight is "
                     "the same for all partitions.")

-        if USE_LLAMA_NN:
+        if self.use_llama_nn:
             assert param_data_.shape == loaded_weight.shape
             param_data_.copy_(loaded_weight)
             if loaded_shard_id == 1:
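For merged weights the transpose cannot happen per shard: transposing spreads every output column across all rows of the fused matrix, so shards must land raw first. The visible branch does exactly that; the body of `if loaded_shard_id == 1:` is collapsed in this view, so the one-shot re-layout after the last shard is an assumption in this sketch:

```python
import torch

# Hypothetical gate/up geometry: two output shards of 8 rows each, 10 inputs.
out, in_ = 8, 10
merged = torch.empty(2 * out, in_)              # fused gate_up buffer, (2*out, in)
gate, up = torch.randn(out, in_), torch.randn(out, in_)

# Visible path: each shard is copied raw into its slice of the fused buffer,
# mirroring param_data.narrow(output_dim, shard_offset, shard_size).
merged.narrow(0, 0, out).copy_(gate)            # loaded_shard_id == 0
merged.narrow(0, out, out).copy_(up)            # loaded_shard_id == 1

# Assumption: the collapsed branch re-lays-out the whole fused weight once,
# after the final shard, matching what apply_weights expects (transposed data
# under the original registered shape).
nn_data = merged.t().contiguous().reshape(merged.shape)
merged.copy_(nn_data)
```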
@@ -404,6 +407,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                        2 * self.num_kv_heads) * tp_size * self.head_size
         super().__init__(input_size, output_size, bias, False, skip_bias_add,
                          params_dtype, linear_method)
+        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

     def weight_loader(self,
                       param: Parameter,
@@ -467,7 +471,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)

-            if USE_LLAMA_NN:
+            if self.use_llama_nn:
                 param_data_ = param_data.narrow(output_dim, shard_offset,
                                                 shard_size)
             else:
@@ -488,7 +492,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                     "QKVParallelLinear, assume the weight is the same "
                     "for all partitions.")

-        if USE_LLAMA_NN:
+        if self.use_llama_nn:
             assert param_data_.shape == loaded_weight.shape
             param_data_.copy_(loaded_weight)
             if loaded_shard_id == "v":
@@ -574,6 +578,7 @@ class RowParallelLinear(torch.nn.Module):
             })
         else:
             self.register_parameter("bias", None)
+        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
@@ -585,7 +590,7 @@ class RowParallelLinear(torch.nn.Module):
             loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                  shard_size)
         assert param_data.shape == loaded_weight.shape
-        if USE_LLAMA_NN:
+        if self.use_llama_nn:
             loaded_weight = loaded_weight.transpose(0, 1)
             loaded_weight = loaded_weight.reshape(param_data.shape[0], -1)
         param_data.copy_(loaded_weight)
@@ -9,6 +9,7 @@ from vllm.config import DeviceConfig, ModelConfig
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
+import os


 @contextlib.contextmanager
@@ -22,6 +23,8 @@ def _set_default_torch_dtype(dtype: torch.dtype):

 def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
     architectures = getattr(model_config.hf_config, "architectures", [])
+    if architectures == ['LlamaForCausalLM']:
+        os.environ['LLAMA_NN'] = '1'
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
     if (model_config.quantization is not None
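This hunk is presumably why the module-level `USE_LLAMA_NN` constant had to become a per-instance read: the constant was frozen when linear.py was imported, before `_get_model_architecture` runs, so it could never observe the flag set here. A minimal timing sketch:

```python
import os

# At import time (before model loading) the old module-level flag froze to False:
os.environ.pop('LLAMA_NN', None)
USE_LLAMA_NN = int(os.environ.get('LLAMA_NN', '0')) == 1   # -> False, permanently

# _get_model_architecture() only sets the variable later, during model loading:
os.environ['LLAMA_NN'] = '1'

# The linear layers are constructed after that, so a per-instance read in
# __init__ sees the flag:
use_llama_nn = os.environ.get('LLAMA_NN') == '1'           # -> True

assert not USE_LLAMA_NN and use_llama_nn
```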