Unverified commit 5e58bdc7, authored by Lucas Wilkinson, committed by GitHub
Browse files

[Bugfix] Remove erroneous lower bound on LoRA vocab size constraint (#35354)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent a1f53add
...@@ -469,7 +469,7 @@ def test_lm_head_logits_processor( ...@@ -469,7 +469,7 @@ def test_lm_head_logits_processor(
@torch.inference_mode() @torch.inference_mode()
@pytest.mark.parametrize("vocab_size", [512, 32000, 258049, 300000]) @pytest.mark.parametrize("vocab_size", [258049, 300000])
@pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("device", DEVICES)
def test_lm_head_logits_processor_invalid_vocab_size( def test_lm_head_logits_processor_invalid_vocab_size(
default_vllm_config, dist_init, vocab_size, device default_vllm_config, dist_init, vocab_size, device
...@@ -489,7 +489,7 @@ def test_lm_head_logits_processor_invalid_vocab_size( ...@@ -489,7 +489,7 @@ def test_lm_head_logits_processor_invalid_vocab_size(
logits_processor, 1024, torch.float16, device, None logits_processor, 1024, torch.float16, device, None
) )
with pytest.raises(ValueError, match="vocab size must be > 32000 and <= 258048"): with pytest.raises(ValueError, match="vocab size must be <= 258048"):
lora_logits_processor.create_lora_weights(max_loras, lora_config) lora_logits_processor.create_lora_weights(max_loras, lora_config)
......
...@@ -88,10 +88,8 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -88,10 +88,8 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config: PretrainedConfig | None = None, model_config: PretrainedConfig | None = None,
) -> None: ) -> None:
# TODO: Verify if this condition can be further relaxed # TODO: Verify if this condition can be further relaxed
if self.base_layer.vocab_size <= 32000 or self.base_layer.vocab_size > 258048: if self.base_layer.vocab_size > 258048:
raise ValueError( raise ValueError("When using LoRA, vocab size must be <= 258048")
"When using LoRA, vocab size must be > 32000 and <= 258048"
)
self.lora_a_stacked = torch.zeros( self.lora_a_stacked = torch.zeros(
( (
max_loras, max_loras,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment