Unverified Commit bc4eb65b authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix Fuyu tensor parallel inference (#8986)

parent 82f3937e
......@@ -37,7 +37,9 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp"),
# TP only models
(2, 1, 1, 0, 0, "adept/fuyu-8b", "mp"),
],
)
@fork_new_process_for_each_test
......
......@@ -237,8 +237,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
self.image_feature_size,
config.hidden_size,
quant_config=quant_config,
gather_output=True,
)
self.language_model = PersimmonForCausalLM(config,
self.language_model = PersimmonForCausalLM(config.text_config,
cache_config=cache_config,
quant_config=quant_config)
......
......@@ -25,11 +25,11 @@ from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PersimmonConfig
from transformers.activations import ReLUSquaredActivation
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
......@@ -57,7 +57,7 @@ class PersimmonMLP(nn.Module):
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size,
quant_config=quant_config)
self.act = ReLUSquaredActivation()
self.act = get_act_fn(config.hidden_act, quant_config)
def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.dense_h_to_4h(hidden_states)
......@@ -96,7 +96,7 @@ class PersimmonAttention(nn.Module):
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.num_heads * self.head_dim,
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=True,
quant_config=quant_config,
......@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.vocab_size = config.text_config.vocab_size
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.text_config.vocab_size, config.hidden_size)
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
PersimmonDecoderLayer(config,
cache_config=cache_config,
......@@ -252,19 +252,19 @@ class PersimmonModel(nn.Module):
class PersimmonForCausalLM(nn.Module):
def __init__(self,
config,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.vocab_size = config.text_config.vocab_size
self.vocab_size = config.vocab_size
self.model = PersimmonModel(config,
cache_config=cache_config,
quant_config=quant_config)
self.lm_head = ParallelLMHead(config.text_config.vocab_size,
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=False)
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment