"vllm/vscode:/vscode.git/clone" did not exist on "6366c098d7c76120b6a55a6829a2649c727a2862"
Unverified Commit 2e9a2227 authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Lora] Support long context lora (#4787)

Currently we need to call rotary embedding kernel for each LoRA, which makes it hard to serve multiple long context length LoRA. Add batched rotary embedding kernel and pipe it through.

It replaces the rotary embedding layer to the one that is aware of multiple cos-sin-cache per scaling factors.

Follow up of https://github.com/vllm-project/vllm/pull/3095/files
parent c0724fc9
...@@ -348,6 +348,8 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -348,6 +348,8 @@ class ChatGLMForCausalLM(nn.Module):
super().__init__() super().__init__()
self.config: ChatGLMConfig = config self.config: ChatGLMConfig = config
self.quant_config = quant_config self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config) self.transformer = ChatGLMModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
......
...@@ -321,12 +321,8 @@ class LlamaForCausalLM(nn.Module): ...@@ -321,12 +321,8 @@ class LlamaForCausalLM(nn.Module):
# LoRA specific attributes # LoRA specific attributes
supported_lora_modules = [ supported_lora_modules = [
"qkv_proj", "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"o_proj", "lm_head"
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
] ]
embedding_modules = { embedding_modules = {
"embed_tokens": "input_embeddings", "embed_tokens": "input_embeddings",
......
...@@ -46,6 +46,8 @@ class ChatGLMConfig(PretrainedConfig): ...@@ -46,6 +46,8 @@ class ChatGLMConfig(PretrainedConfig):
self.kv_channels = kv_channels self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
self.seq_length = seq_length self.seq_length = seq_length
# It is to be compatible with long lora.
self.max_position_embeddings = seq_length
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon self.layernorm_epsilon = layernorm_epsilon
......
...@@ -34,12 +34,26 @@ class TokenizerGroup(BaseTokenizerGroup): ...@@ -34,12 +34,26 @@ class TokenizerGroup(BaseTokenizerGroup):
"""Get the maximum input length for the LoRA request.""" """Get the maximum input length for the LoRA request."""
return self.max_input_length return self.max_input_length
def _raise_if_input_too_long(self,
encoded_tokens: List[str],
lora_request: Optional[LoRARequest] = None):
input_length = len(encoded_tokens)
if lora_request:
max_input_length = (lora_request.long_lora_max_len
or self.max_input_length)
else:
max_input_length = self.max_input_length
if max_input_length is not None and input_length > max_input_length:
raise ValueError("Input too long.", input_length, max_input_length)
def encode(self, def encode(self,
prompt: str, prompt: str,
request_id: Optional[str] = None, request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]: lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request) tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt) ret = tokenizer.encode(prompt)
self._raise_if_input_too_long(ret, lora_request)
return ret
async def encode_async( async def encode_async(
self, self,
...@@ -47,7 +61,9 @@ class TokenizerGroup(BaseTokenizerGroup): ...@@ -47,7 +61,9 @@ class TokenizerGroup(BaseTokenizerGroup):
request_id: Optional[str] = None, request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]: lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request) tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt) ret = tokenizer.encode(prompt)
self._raise_if_input_too_long(ret, lora_request)
return ret
def get_lora_tokenizer( def get_lora_tokenizer(
self, self,
......
...@@ -156,9 +156,15 @@ class ModelRunner: ...@@ -156,9 +156,15 @@ class ModelRunner:
), "Model does not have embedding_padding_modules" ), "Model does not have embedding_padding_modules"
self.lora_manager = LRUCacheWorkerLoRAManager( self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens, self.vocab_size, self.scheduler_config.max_num_batched_tokens,
self.lora_config, self.device, self.model.embedding_modules, self.vocab_size,
self.model.embedding_padding_modules) self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=self.model.config.
max_position_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model) self.model = self.lora_manager.create_lora_manager(self.model)
if self.kv_cache_dtype == "fp8" and is_hip(): if self.kv_cache_dtype == "fp8" and is_hip():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment