Unverified Commit 97b03000 authored by raywanb's avatar raywanb Committed by GitHub
Browse files

[Model] LoRA gptbigcode implementation (#3949)

parent a3a73ab0
...@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \ f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 3072) \ f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3328) \
f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \ f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \ f(in_T, out_T, W_T, narrow, 4096) \
...@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \ f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \ f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6400) \
f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \ f(in_T, out_T, W_T, narrow, 7168) \
...@@ -97,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -97,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 2752, narrow) \ f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \ f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \ f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3328, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \ f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \ f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \ f(in_T, out_T, W_T, 4096, narrow) \
...@@ -105,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -105,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 5504, narrow) \ f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \ f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \ f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6400, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \ f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \ f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \ f(in_T, out_T, W_T, 7168, narrow) \
......
...@@ -58,6 +58,7 @@ H1 = H2 = [ ...@@ -58,6 +58,7 @@ H1 = H2 = [
2560, 2560,
2752, 2752,
3072, 3072,
3328,
3456, 3456,
3584, 3584,
4096, 4096,
...@@ -66,6 +67,7 @@ H1 = H2 = [ ...@@ -66,6 +67,7 @@ H1 = H2 = [
5504, 5504,
5632, 5632,
6144, 6144,
6400,
6848, 6848,
6912, 6912,
7168, 7168,
......
...@@ -310,7 +310,9 @@ class LoRAModel: ...@@ -310,7 +310,9 @@ class LoRAModel:
if part_name not in expected_lora_modules: if part_name not in expected_lora_modules:
unexpected_modules.append(module) unexpected_modules.append(module)
# loaded lora's target modules must be a subset of expected_lora_modules # loaded lora's target modules must be a subset of expected_lora_modules
if unexpected_modules: if unexpected_modules:
print(unexpected_modules, "modules")
raise ValueError( raise ValueError(
f"While loading {lora_dir}, expected" f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}" f" target modules in {expected_lora_modules}"
......
...@@ -25,7 +25,7 @@ from torch import nn ...@@ -25,7 +25,7 @@ from torch import nn
from transformers import GPTBigCodeConfig from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -191,14 +191,19 @@ class GPTBigCodeModel(nn.Module): ...@@ -191,14 +191,19 @@ class GPTBigCodeModel(nn.Module):
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
assert not config.add_cross_attention assert not config.add_cross_attention
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
lora_vocab = (lora_config.lora_extra_vocab_size *
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) (lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.wte = VocabParallelEmbedding(self.vocab_size,
self.embed_dim,
org_num_embeddings=config.vocab_size)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([ self.h = nn.ModuleList([
GPTBigCodeBlock(config, cache_config, quant_config) GPTBigCodeBlock(config, cache_config, quant_config)
...@@ -226,19 +231,35 @@ class GPTBigCodeModel(nn.Module): ...@@ -226,19 +231,35 @@ class GPTBigCodeModel(nn.Module):
class GPTBigCodeForCausalLM(nn.Module): class GPTBigCodeForCausalLM(nn.Module):
packed_modules_mapping = {"c_attn": ["c_attn"]}
supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
embedding_modules = {
"wte": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = []
def __init__( def __init__(
self, self,
config: GPTBigCodeConfig, config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config) self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
def forward( def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment