Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc4eb65b
Unverified
Commit
bc4eb65b
authored
Oct 01, 2024
by
Isotr0py
Committed by
GitHub
Oct 01, 2024
Browse files
[Bugfix] Fix Fuyu tensor parallel inference (#8986)
parent
82f3937e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
12 deletions
+15
-12
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+3
-1
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+2
-1
vllm/model_executor/models/persimmon.py
vllm/model_executor/models/persimmon.py
+10
-10
No files found.
tests/distributed/test_pipeline_parallel.py
View file @
bc4eb65b
...
@@ -37,7 +37,9 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
...
@@ -37,7 +37,9 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
(
1
,
2
,
1
,
1
,
1
,
"OpenGVLab/InternVL2-1B"
,
"mp"
),
(
1
,
2
,
1
,
1
,
1
,
"OpenGVLab/InternVL2-1B"
,
"mp"
),
(
1
,
2
,
1
,
1
,
1
,
"OpenGVLab/InternVL2-2B"
,
"mp"
),
(
1
,
2
,
1
,
1
,
1
,
"OpenGVLab/InternVL2-2B"
,
"mp"
),
(
1
,
2
,
1
,
0
,
1
,
"OpenGVLab/InternVL2-4B"
,
"mp"
),
(
1
,
2
,
1
,
0
,
1
,
"OpenGVLab/InternVL2-4B"
,
"mp"
),
(
1
,
2
,
0
,
1
,
0
,
"Qwen/Qwen2-VL-2B-Instruct"
,
"mp"
)
(
1
,
2
,
0
,
1
,
0
,
"Qwen/Qwen2-VL-2B-Instruct"
,
"mp"
),
# TP only models
(
2
,
1
,
1
,
0
,
0
,
"adept/fuyu-8b"
,
"mp"
),
],
],
)
)
@
fork_new_process_for_each_test
@
fork_new_process_for_each_test
...
...
vllm/model_executor/models/fuyu.py
View file @
bc4eb65b
...
@@ -237,8 +237,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
...
@@ -237,8 +237,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
self
.
image_feature_size
,
self
.
image_feature_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
gather_output
=
True
,
)
)
self
.
language_model
=
PersimmonForCausalLM
(
config
,
self
.
language_model
=
PersimmonForCausalLM
(
config
.
text_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
...
...
vllm/model_executor/models/persimmon.py
View file @
bc4eb65b
...
@@ -25,11 +25,11 @@ from typing import Iterable, List, Optional, Tuple
...
@@ -25,11 +25,11 @@ from typing import Iterable, List, Optional, Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PersimmonConfig
from
transformers
import
PersimmonConfig
from
transformers.activations
import
ReLUSquaredActivation
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
...
@@ -57,7 +57,7 @@ class PersimmonMLP(nn.Module):
...
@@ -57,7 +57,7 @@ class PersimmonMLP(nn.Module):
self
.
dense_4h_to_h
=
RowParallelLinear
(
config
.
intermediate_size
,
self
.
dense_4h_to_h
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
act
=
ReLUSquaredActivation
(
)
self
.
act
=
get_act_fn
(
config
.
hidden_act
,
quant_config
)
def
forward
(
self
,
hidden_states
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
dense_h_to_4h
(
hidden_states
)
hidden_states
,
_
=
self
.
dense_h_to_4h
(
hidden_states
)
...
@@ -96,7 +96,7 @@ class PersimmonAttention(nn.Module):
...
@@ -96,7 +96,7 @@ class PersimmonAttention(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
self
.
dense
=
RowParallelLinear
(
self
.
dense
=
RowParallelLinear
(
self
.
num_heads
*
self
.
head_dim
,
self
.
total_
num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
...
@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module):
...
@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
vocab_size
=
config
.
text_
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
text_config
.
vocab_size
,
config
.
hidden_size
)
config
.
hidden_size
)
self
.
layers
=
nn
.
ModuleList
([
self
.
layers
=
nn
.
ModuleList
([
PersimmonDecoderLayer
(
config
,
PersimmonDecoderLayer
(
config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
...
@@ -252,19 +252,19 @@ class PersimmonModel(nn.Module):
...
@@ -252,19 +252,19 @@ class PersimmonModel(nn.Module):
class
PersimmonForCausalLM
(
nn
.
Module
):
class
PersimmonForCausalLM
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
config
,
config
:
PersimmonConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
vocab_size
=
config
.
text_
config
.
vocab_size
self
.
vocab_size
=
config
.
vocab_size
self
.
model
=
PersimmonModel
(
config
,
self
.
model
=
PersimmonModel
(
config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
text_config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
False
)
bias
=
False
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
text_config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment