Unverified Commit 2e9a2227 authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[Lora] Support long context lora (#4787)

Currently we need to call rotary embedding kernel for each LoRA, which makes it hard to serve multiple long context length LoRA. Add batched rotary embedding kernel and pipe it through.

It replaces the rotary embedding layer to the one that is aware of multiple cos-sin-cache per scaling factors.

Follow up of https://github.com/vllm-project/vllm/pull/3095/files
parent c0724fc9
......@@ -348,6 +348,8 @@ class ChatGLMForCausalLM(nn.Module):
super().__init__()
self.config: ChatGLMConfig = config
self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
......
......@@ -321,12 +321,8 @@ class LlamaForCausalLM(nn.Module):
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
"lm_head"
]
embedding_modules = {
"embed_tokens": "input_embeddings",
......
......@@ -46,6 +46,8 @@ class ChatGLMConfig(PretrainedConfig):
self.kv_channels = kv_channels
self.num_attention_heads = num_attention_heads
self.seq_length = seq_length
# It is to be compatible with long lora.
self.max_position_embeddings = seq_length
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.layernorm_epsilon = layernorm_epsilon
......
......@@ -34,12 +34,26 @@ class TokenizerGroup(BaseTokenizerGroup):
"""Get the maximum input length for the LoRA request."""
return self.max_input_length
def _raise_if_input_too_long(self,
encoded_tokens: List[str],
lora_request: Optional[LoRARequest] = None):
input_length = len(encoded_tokens)
if lora_request:
max_input_length = (lora_request.long_lora_max_len
or self.max_input_length)
else:
max_input_length = self.max_input_length
if max_input_length is not None and input_length > max_input_length:
raise ValueError("Input too long.", input_length, max_input_length)
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt)
ret = tokenizer.encode(prompt)
self._raise_if_input_too_long(ret, lora_request)
return ret
async def encode_async(
self,
......@@ -47,7 +61,9 @@ class TokenizerGroup(BaseTokenizerGroup):
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt)
ret = tokenizer.encode(prompt)
self._raise_if_input_too_long(ret, lora_request)
return ret
def get_lora_tokenizer(
self,
......
......@@ -156,9 +156,15 @@ class ModelRunner:
), "Model does not have embedding_padding_modules"
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens, self.vocab_size,
self.lora_config, self.device, self.model.embedding_modules,
self.model.embedding_padding_modules)
self.scheduler_config.max_num_batched_tokens,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules,
max_position_embeddings=self.model.config.
max_position_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)
if self.kv_cache_dtype == "fp8" and is_hip():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment