Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc4eb65b
Unverified
Commit
bc4eb65b
authored
Oct 01, 2024
by
Isotr0py
Committed by
GitHub
Oct 01, 2024
Browse files
[Bugfix] Fix Fuyu tensor parallel inference (#8986)
parent
82f3937e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
12 deletions
+15
-12
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+3
-1
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+2
-1
vllm/model_executor/models/persimmon.py
vllm/model_executor/models/persimmon.py
+10
-10
No files found.
tests/distributed/test_pipeline_parallel.py
View file @
bc4eb65b
...
...
@@ -37,7 +37,9 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
(
1
,
2
,
1
,
1
,
1
,
"OpenGVLab/InternVL2-1B"
,
"mp"
),
(
1
,
2
,
1
,
1
,
1
,
"OpenGVLab/InternVL2-2B"
,
"mp"
),
(
1
,
2
,
1
,
0
,
1
,
"OpenGVLab/InternVL2-4B"
,
"mp"
),
(
1
,
2
,
0
,
1
,
0
,
"Qwen/Qwen2-VL-2B-Instruct"
,
"mp"
)
(
1
,
2
,
0
,
1
,
0
,
"Qwen/Qwen2-VL-2B-Instruct"
,
"mp"
),
# TP only models
(
2
,
1
,
1
,
0
,
0
,
"adept/fuyu-8b"
,
"mp"
),
],
)
@
fork_new_process_for_each_test
...
...
vllm/model_executor/models/fuyu.py
View file @
bc4eb65b
...
...
@@ -237,8 +237,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
self
.
image_feature_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
gather_output
=
True
,
)
self
.
language_model
=
PersimmonForCausalLM
(
config
,
self
.
language_model
=
PersimmonForCausalLM
(
config
.
text_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
...
...
vllm/model_executor/models/persimmon.py
View file @
bc4eb65b
...
...
@@ -25,11 +25,11 @@ from typing import Iterable, List, Optional, Tuple
import
torch
from
torch
import
nn
from
transformers
import
PersimmonConfig
from
transformers.activations
import
ReLUSquaredActivation
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
...
...
@@ -57,7 +57,7 @@ class PersimmonMLP(nn.Module):
self
.
dense_4h_to_h
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
self
.
act
=
ReLUSquaredActivation
(
)
self
.
act
=
get_act_fn
(
config
.
hidden_act
,
quant_config
)
def
forward
(
self
,
hidden_states
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
dense_h_to_4h
(
hidden_states
)
...
...
@@ -96,7 +96,7 @@ class PersimmonAttention(nn.Module):
quant_config
=
quant_config
,
)
self
.
dense
=
RowParallelLinear
(
self
.
num_heads
*
self
.
head_dim
,
self
.
total_
num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
,
...
...
@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
vocab_size
=
config
.
text_
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
text_config
.
vocab_size
,
config
.
hidden_size
)
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
PersimmonDecoderLayer
(
config
,
cache_config
=
cache_config
,
...
...
@@ -252,19 +252,19 @@ class PersimmonModel(nn.Module):
class
PersimmonForCausalLM
(
nn
.
Module
):
def
__init__
(
self
,
config
,
config
:
PersimmonConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
self
.
config
=
config
self
.
vocab_size
=
config
.
text_
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
self
.
model
=
PersimmonModel
(
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
text_config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
bias
=
False
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
text_config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment