Unverified Commit bc4eb65b authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix Fuyu tensor parallel inference (#8986)

parent 82f3937e
...@@ -37,7 +37,9 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" ...@@ -37,7 +37,9 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"), (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"), (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"), (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp") (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp"),
# TP only models
(2, 1, 1, 0, 0, "adept/fuyu-8b", "mp"),
], ],
) )
@fork_new_process_for_each_test @fork_new_process_for_each_test
......
...@@ -237,8 +237,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal): ...@@ -237,8 +237,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
self.image_feature_size, self.image_feature_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
gather_output=True,
) )
self.language_model = PersimmonForCausalLM(config, self.language_model = PersimmonForCausalLM(config.text_config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
......
...@@ -25,11 +25,11 @@ from typing import Iterable, List, Optional, Tuple ...@@ -25,11 +25,11 @@ from typing import Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import PersimmonConfig from transformers import PersimmonConfig
from transformers.activations import ReLUSquaredActivation
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
...@@ -57,7 +57,7 @@ class PersimmonMLP(nn.Module): ...@@ -57,7 +57,7 @@ class PersimmonMLP(nn.Module):
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
self.act = ReLUSquaredActivation() self.act = get_act_fn(config.hidden_act, quant_config)
def forward(self, hidden_states) -> torch.Tensor: def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.dense_h_to_4h(hidden_states) hidden_states, _ = self.dense_h_to_4h(hidden_states)
...@@ -96,7 +96,7 @@ class PersimmonAttention(nn.Module): ...@@ -96,7 +96,7 @@ class PersimmonAttention(nn.Module):
quant_config=quant_config, quant_config=quant_config,
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.num_heads * self.head_dim, self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
...@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module): ...@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.vocab_size = config.text_config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.text_config.vocab_size, config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
PersimmonDecoderLayer(config, PersimmonDecoderLayer(config,
cache_config=cache_config, cache_config=cache_config,
...@@ -252,19 +252,19 @@ class PersimmonModel(nn.Module): ...@@ -252,19 +252,19 @@ class PersimmonModel(nn.Module):
class PersimmonForCausalLM(nn.Module): class PersimmonForCausalLM(nn.Module):
def __init__(self, def __init__(self,
config, config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.vocab_size = config.text_config.vocab_size self.vocab_size = config.vocab_size
self.model = PersimmonModel(config, self.model = PersimmonModel(config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.lm_head = ParallelLMHead(config.text_config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
bias=False) bias=False)
self.logits_processor = LogitsProcessor(config.text_config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
def forward( def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment