Commit 89683b9e authored by zhuwenwen's avatar zhuwenwen
Browse files

add llama nn support

parent 3e147e19
...@@ -13,6 +13,8 @@ from vllm.model_executor.parallel_utils.utils import ( ...@@ -13,6 +13,8 @@ from vllm.model_executor.parallel_utils.utils import (
divide, split_tensor_along_last_dim) divide, split_tensor_along_last_dim)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.logger import init_logger from vllm.logger import init_logger
import os
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -55,6 +57,7 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -55,6 +57,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self, separate_bias_add: bool = False): def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add self.separate_bias_add = separate_bias_add
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def create_weights(self, input_size_per_partition: int, def create_weights(self, input_size_per_partition: int,
output_size_per_partition: int, input_size: int, output_size_per_partition: int, input_size: int,
...@@ -76,7 +79,15 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -76,7 +79,15 @@ class UnquantizedLinearMethod(LinearMethodBase):
if bias: if bias:
return F.linear(x, weight) + bias return F.linear(x, weight) + bias
return F.linear(x, weight) return F.linear(x, weight)
return F.linear(x, weight, bias)
if self.use_llama_nn:
weight = weight.reshape(weight.shape[1], -1)
if bias is not None:
return torch.matmul(x, weight) + bias
else:
return torch.matmul(x, weight)
else:
return F.linear(x, weight, bias)
class ReplicatedLinear(torch.nn.Module): class ReplicatedLinear(torch.nn.Module):
...@@ -195,6 +206,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -195,6 +206,7 @@ class ColumnParallelLinear(torch.nn.Module):
}) })
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -206,6 +218,9 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -206,6 +218,9 @@ class ColumnParallelLinear(torch.nn.Module):
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(param_data.shape[0],-1)
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def forward(self, input_): def forward(self, input_):
...@@ -259,6 +274,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -259,6 +274,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output, super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, linear_method) skip_bias_add, params_dtype, linear_method)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -311,8 +327,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -311,8 +327,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
param_data = param_data.narrow(output_dim, shard_offset, if self.use_llama_nn:
shard_size) param_data_ = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
...@@ -323,8 +343,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -323,8 +343,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in " "Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is " "MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.") "the same for all partitions.")
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
if loaded_shard_id == 1 and len(param_data.shape) == 2:
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear): class QKVParallelLinear(ColumnParallelLinear):
...@@ -383,6 +411,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -383,6 +411,7 @@ class QKVParallelLinear(ColumnParallelLinear):
2 * self.num_kv_heads) * tp_size * self.head_size 2 * self.num_kv_heads) * tp_size * self.head_size
super().__init__(input_size, output_size, bias, False, skip_bias_add, super().__init__(input_size, output_size, bias, False, skip_bias_add,
params_dtype, linear_method) params_dtype, linear_method)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -446,7 +475,11 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -446,7 +475,11 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
param_data = param_data.narrow(output_dim, shard_offset, if self.use_llama_nn:
param_data_ = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
if loaded_shard_id == "q": if loaded_shard_id == "q":
shard_id = tp_rank shard_id = tp_rank
...@@ -462,8 +495,16 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -462,8 +495,16 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in " "Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same " "QKVParallelLinear, assume the weight is the same "
"for all partitions.") "for all partitions.")
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) if self.use_llama_nn:
assert param_data_.shape == loaded_weight.shape
param_data_.copy_(loaded_weight)
if loaded_shard_id == "v" and len(param_data.shape) == 2:
param_data = param_data.transpose(0, 1)
param.data = param_data.reshape(param_data.shape[1], -1)
else:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class RowParallelLinear(torch.nn.Module): class RowParallelLinear(torch.nn.Module):
...@@ -541,6 +582,7 @@ class RowParallelLinear(torch.nn.Module): ...@@ -541,6 +582,7 @@ class RowParallelLinear(torch.nn.Module):
}) })
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -552,6 +594,9 @@ class RowParallelLinear(torch.nn.Module): ...@@ -552,6 +594,9 @@ class RowParallelLinear(torch.nn.Module):
loaded_weight = loaded_weight.narrow(input_dim, start_idx, loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size) shard_size)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight=loaded_weight.reshape(param_data.shape[0],-1)
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
def forward(self, input_): def forward(self, input_):
...@@ -578,4 +623,4 @@ class RowParallelLinear(torch.nn.Module): ...@@ -578,4 +623,4 @@ class RowParallelLinear(torch.nn.Module):
else: else:
output = output_ output = output_
output_bias = self.bias output_bias = self.bias
return output, output_bias return output, output_bias
\ No newline at end of file
...@@ -9,6 +9,7 @@ from vllm.config import DeviceConfig, ModelConfig ...@@ -9,6 +9,7 @@ from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config, from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights) initialize_dummy_weights)
import os
@contextlib.contextmanager @contextlib.contextmanager
...@@ -22,6 +23,8 @@ def _set_default_torch_dtype(dtype: torch.dtype): ...@@ -22,6 +23,8 @@ def _set_default_torch_dtype(dtype: torch.dtype):
def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
if architectures == ['LlamaForCausalLM']:
os.environ['LLAMA_NN'] = '1'
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None if (model_config.quantization is not None
...@@ -61,6 +64,9 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, ...@@ -61,6 +64,9 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
f"method {model_config.quantization}. Supported dtypes: " f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}") f"{supported_dtypes}")
linear_method = quant_config.get_linear_method() linear_method = quant_config.get_linear_method()
if linear_method != None:
os.environ['LLAMA_NN'] = '0'
with _set_default_torch_dtype(model_config.dtype): with _set_default_torch_dtype(model_config.dtype):
# Create a model instance. # Create a model instance.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment