Unverified commit f84bf7d7, authored by ZT-AIA and committed by GitHub
Browse files

Add LoRAConfig parameter (`lora_config`) to the get_punica_wrapper function (#31408)


Signed-off-by: ZT-AIA <1028681969@qq.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent 99dcf5dc
...@@ -261,11 +261,11 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: ...@@ -261,11 +261,11 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16 max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
) )
punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
assert check_punica_wrapper(punica_wrapper)
def create_random_embedding_layer(): def create_random_embedding_layer():
embedding = VocabParallelEmbedding(vocab_size, 256) embedding = VocabParallelEmbedding(vocab_size, 256)
...@@ -360,11 +360,11 @@ def test_lm_head_logits_processor( ...@@ -360,11 +360,11 @@ def test_lm_head_logits_processor(
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16 max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16
) )
punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
assert check_punica_wrapper(punica_wrapper)
def _pretest(): def _pretest():
linear = ParallelLMHead( linear = ParallelLMHead(
...@@ -480,13 +480,13 @@ def test_linear_replicated( ...@@ -480,13 +480,13 @@ def test_linear_replicated(
max_loras = 8 max_loras = 8
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_loras=max_loras, max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
lora_dtype=torch.float16, lora_dtype=torch.float16,
) )
punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
assert check_punica_wrapper(punica_wrapper)
def create_random_linear_replicated_layer(): def create_random_linear_replicated_layer():
linear = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16) linear = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16)
...@@ -587,14 +587,14 @@ def test_linear_parallel( ...@@ -587,14 +587,14 @@ def test_linear_parallel(
max_loras = 8 max_loras = 8
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_loras=max_loras, max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
fully_sharded_loras=fully_shard, fully_sharded_loras=fully_shard,
lora_dtype=torch.float16, lora_dtype=torch.float16,
) )
punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
assert check_punica_wrapper(punica_wrapper)
def create_random_linear_parallel_layer(): def create_random_linear_parallel_layer():
if orientation == "row": if orientation == "row":
...@@ -712,14 +712,14 @@ def test_column_parallel_packed( ...@@ -712,14 +712,14 @@ def test_column_parallel_packed(
max_loras = 8 max_loras = 8
torch.set_default_device(device) torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
assert check_punica_wrapper(punica_wrapper)
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_loras=max_loras, max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
fully_sharded_loras=fully_shard, fully_sharded_loras=fully_shard,
lora_dtype=torch.float16, lora_dtype=torch.float16,
) )
punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
assert check_punica_wrapper(punica_wrapper)
def create_column_parallel_packed_layer(): def create_column_parallel_packed_layer():
if repeats == 2: if repeats == 2:
......
...@@ -128,7 +128,7 @@ class LoRAModelManager: ...@@ -128,7 +128,7 @@ class LoRAModelManager:
max_num_batched_tokens, max_num_batched_tokens,
max_batches=self.max_num_seqs, max_batches=self.max_num_seqs,
device=self.device, device=self.device,
max_loras=self.lora_config.max_loras, lora_config=self.lora_config,
) )
self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY] = ( self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY] = (
...@@ -148,7 +148,7 @@ class LoRAModelManager: ...@@ -148,7 +148,7 @@ class LoRAModelManager:
max_num_batched_tokens, max_num_batched_tokens,
max_batches=self.max_num_seqs, max_batches=self.max_num_seqs,
device=self.device, device=self.device,
max_loras=self.lora_config.max_loras, lora_config=self.lora_config,
) )
lm_prefix = self.mm_mapping.language_model[0] lm_prefix = self.mm_mapping.language_model[0]
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
...@@ -186,7 +186,7 @@ class LoRAModelManager: ...@@ -186,7 +186,7 @@ class LoRAModelManager:
num_encoder_tokens, num_encoder_tokens,
max_batches=self.max_num_seqs * limit_per_prompt, max_batches=self.max_num_seqs * limit_per_prompt,
device=self.device, device=self.device,
max_loras=self.lora_config.max_loras, lora_config=self.lora_config,
) )
for prefix in self.mm_mapping.tower_model: for prefix in self.mm_mapping.tower_model:
self.punica_wrapper_mapping[prefix] = tower_punica_wrapper self.punica_wrapper_mapping[prefix] = tower_punica_wrapper
...@@ -201,7 +201,7 @@ class LoRAModelManager: ...@@ -201,7 +201,7 @@ class LoRAModelManager:
connector_tokens, connector_tokens,
max_batches=self.max_num_seqs * limit_per_prompt, max_batches=self.max_num_seqs * limit_per_prompt,
device=self.device, device=self.device,
max_loras=self.lora_config.max_loras, lora_config=self.lora_config,
) )
for prefix in self.mm_mapping.connector: for prefix in self.mm_mapping.connector:
self.punica_wrapper_mapping[prefix] = connector_punica_wrapper self.punica_wrapper_mapping[prefix] = connector_punica_wrapper
......
...@@ -45,7 +45,8 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -45,7 +45,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
): ):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
self.max_loras = kwargs["max_loras"] self.lora_config = kwargs["lora_config"]
self.max_loras = self.lora_config.max_loras
self.token_mapping_meta = LoRAKernelMeta.make( self.token_mapping_meta = LoRAKernelMeta.make(
self.max_loras, max_num_batched_tokens, device=device self.max_loras, max_num_batched_tokens, device=device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.